{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n",
      "(3, 3, 192, 192) (3, 6, 192, 192)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAKvCAYAAAArysUEAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X+sXHd95//ne53CH2lWCc1dy3Jymx8yIKi2F7jyFtGgsCltEuVLSCOlsSoINOpN9CXSVlup30CqgipF6rak0VbdBm4UK2YFJrRuSlT525KNUEOrpOQavMYBAnbqKLaMfUlYEgEKdfzeP+65dLie6zsz58xnZs48H9JoznzmnDnv4/j4lc/nnPlMZCaSJKmcfzfqAiRJmjaGryRJhRm+kiQVZvhKklSY4StJUmGGryRJhQ0tfCPi6oh4JiIORcSdw9qPJEmTJobxPd+I2AR8C3g3cBR4CtiRmV9vfGeSJE2YYfV8twOHMvPZzPwx8Fng+iHtS5KkiXLOkD53K/B8x+ujwH9ab+ULL7wwL7nkkiGVIo23ffv2fTczZ0ZdR1M8nzXNej2fhxW+G4qIBWABYHZ2lqWlpVGVIo1URDw36hrq8nyWVvR6Pg9r2PkYcHHH64uqtp/IzMXMnM/M+ZmZ1vxPvzSVPJ+l/gwrfJ8CtkXEpRHxGuBm4JEh7UuSpIkylGHnzDwVEXcAfw9sAnZm5tPD2JckSZNmaNd8M3MvsHdYny9J0qRyhitJkgozfCVJKszwlSSpMMNXkqTCDF9JkgozfCVJKszwlSSpMMNXkqTCDF9JkgozfCVJKszwlSSpMMNXkqTCDF9JkgozfCVJKszwlSSpMMNXkqTCBg7fiLg4Ir4YEV+PiKcj4r9U7R+LiGMRsb96XNtcuZIkTb5zamx7CvjdzPxKRJwH7IuIR6v37s3Mj9cvT5Kk9hk4fDPzOHC8Wn45Ir4BbG2qMEmS2qqRa74RcQnwFuCfq6Y7IuJAROyMiAua2IckSW1RO3wj4meBPcDvZOZLwH3A5cAcKz3je9bZbiEiliJiaXl5uW4ZkkbI81nqT63wjYifYSV4P52Zfw2QmScy89XMPA3cD2zvtm1mLmbmfGbOz8zM1ClD0oh5Pkv9qXO3cwAPAN/IzD/taN/SsdoNwMHBy5MkqX3q3O38DuB9wNciYn/V9hFgR0TMAQkcAW6rVaEkSS1T527nfwSiy1t7By+njIggM0ddhqQGxCevI2/721GXIfVlame4Whk1l9QG8cnrRl2C1JepDV8wgKU2MYA1SaY6fMEAltrEANakmPrwBQNYahMDWJPA8K0YwFJ7GMAad4ZvBwNYag8DWOPM8F3DAJbawwDWuDJ8uzCApfYwgDWODN91GMBSexjAGjeG71kYwFJ7GMAaJ3Xmdp5YTi0ptYdTS2oS2fOVJKkww1eSpMIMX0mSCjN8JUkqbCpvuJIk9eftl9/d9zZPHL5rCJW0Q+3wjYgjwMvAq8CpzJyPiNcBDwGXAEeAmzLze3X3pckTEWRmz8+SNA2aGnZ+V2bOZeZ89fpO4LHM3AY8Vr3WFFoN1F6fJWkaDOua7/XArmp5F/DeIe1HY251opJenyVpGjQRvgl8ISL2RcRC1bY5M49Xy98BNjewH00ge76SdKYmbrj65cw8FhH/AXg0Ir7Z+WZmZkSc8S9rFdQLALOzsw2UoXHkNd/p4Pks9ad2zzczj1XPJ4GHge3AiYjYAlA9n+yy3WJmzmfm/MzMTN0yNKbs+U4Hz2epP7XCNyLOjYjzVpeBXwUOAo8At1Sr3QJ8vs5+NLm85itJZ6o77LwZeLj6h/Mc4DOZ+XcR8RTwuYi4FXgOuKnmfjSh7PlK0plqhW9mPgv8Ypf2F4Cr6ny22sFrvpJ0JqeX1FDZ85WkMxm+Giqv+UrSmQxfDZU9X0k6kz+sIEnakD+S0Cx7vpIkFWb4SpJUmOErSVJh
hq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUN/MMKEfEG4KGOpsuAPwDOB34bWK7aP5KZeweuUJKklhk4fDPzGWAOICI2AceAh4EPAvdm5scbqVCSpJZpatj5KuBwZj7X0OdJktRaTYXvzcDujtd3RMSBiNgZERc0tA9JklqhdvhGxGuA9wB/WTXdB1zOypD0ceCedbZbiIiliFhaXl7utoqkCeH5LPWniZ7vNcBXMvMEQGaeyMxXM/M0cD+wvdtGmbmYmfOZOT8zM9NAGZJGxfNZ6k8T4buDjiHniNjS8d4NwMEG9iFJUmsMfLczQEScC7wbuK2j+Y8jYg5I4Mia9yRJmnq1wjczfwD83Jq299WqSJKklnOGK0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMLO6WWliNgJXAeczMxfqNpeBzwEXAIcAW7KzO9FRAD/HbgW+CHwgcz8SvOlt8PKH1dvMnOIlUiqa2HPsZ7XXbxx6xAr0bjrtef7IHD1mrY7gccycxvwWPUa4BpgW/VYAO6rX2b7RERfwTvoNpKGb2HPsb6Cd9Bt1B49hW9mPg68uKb5emBXtbwLeG9H+6dyxZPA+RGxpYli26JugBrA0vioG6AG8HTqadh5HZsz83i1/B1gc7W8FXi+Y72jVdtxplyTobn6WQ5FS6PRZGiufpZD0dOjTvj+RGZmRPSVAhGxwMqwNLOzs02UMdY2Ct6zhejZto0IA1gjN23n80bBe7YQPdu2C3uOGcBTok74noiILZl5vBpWPlm1HwMu7ljvoqrtp2TmIrAIMD8/P7Xp0Utwrq7jcLPGlefzil6Cc3Udh5unW52vGj0C3FIt3wJ8vqP9/bHil4DvdwxPT6X1QrPfHut66xvKUjnrhWa/Pdb11jeUp0NP4RsRu4EngDdExNGIuBX4I+DdEfFt4Feq1wB7gWeBQ8D9wP/beNUTpKng3Wg7A1gavqaCd6PtDOD262nYOTN3rPPWVV3WTeBDdYpqu7rXaDPTsJXGRN1rtIs3bjVsp5AzXA1Rt4Bs6uaobp9jIEvD0y0gm7o5qtvnGMjtZvhKklSY4VtQ018J8itG0ug0/ZUgv2I0XQxfSZIKM3wlSSqskRmupBL8BSipPb78jnf1vO72f/riECsZDXu+kiQVZvhKklSY4StJUmGGb0FNT4LhpBrS6DQ9CYaTakwXw1eSpMIM3yEa5hSQw5y6UtKZhjkF5DCnrtR4MnxHoG4AO9wsjY+6Aexw83QyfIes6Z8AbPonCiX1rumfAGz6Jwo1OQzfApoKYINXGr2mAtjgnW7OcDViq4F6tgB1mFmaDKuBerYAdZhZ0EP4RsRO4DrgZGb+QtX2J8D/A/wYOAx8MDP/T0RcAnwDeKba/MnMvH0IdU+czDxriA4asNPU652mY9V4W7xx61lDdNCAnaZebxunjOxHLz3fB4E/Bz7V0fYo8OHMPBUR/w34MPD/Ve8dzsy5RqtsidXwaKInaxBJo7UalE30ZKcpdLViw2u+mfk48OKati9k5qnq5ZPARUOorbXqBqfBK42PusFp8E6nJq75/hbwUMfrSyPiq8BLwO9n5pca2EfrDNILNnSl8TRIL9jQnW61wjci7gJOAZ+umo4Ds5n5QkS8DfibiHhzZr7UZdsFYAFgdna2ThkTzUBVG3g+rzBQ1auBv2oUER9g5Uas38wqQTLzlcx8oVrex8rNWK/vtn1mLmbmfGbOz8zMDFqGpDHg+Sz1Z6DwjYirgd8D3pOZP+xon4mITdXyZcA24NkmCpUkqS16+arRbuBK4MKIOAp8lJW7m18LPFpds1z9StE7gT+MiH8FTgO3Z+aLXT9YkqQptWH4ZuaOLs0PrLPuHmBP3aIkSWozp5eUJKkww1eSpMIMX0mS
CjN8JUkqzPCVJKkww1eSpML8PV9NvW7zazvtpzSZXt5/Zp/yvLnTI6jk7Oz5aqqt98MWTfzso6SyugXv2dpHafwqkgrZKGANYGlybBSw4xbA41WNJElTwPCVJKkww1eSpMIMX0mSCjN8NbU2+jqRXzeSJsdGXycat68bGb6aausFrMErTZ71AnbcghecZENjZvXrPSXDz6CVhmPXl94OwC1XPFFsn+MYtN1s2PONiJ0RcTIiDna0fSwijkXE/upxbcd7H46IQxHxTET82rAKlyRpUvUy7PwgcHWX9nszc6567AWIiDcBNwNvrrb5i4jY1FSxkiS1wYbhm5mPAy/2+HnXA5/NzFcy81+AQ8D2GvVJktQ6dW64uiMiDlTD0hdUbVuB5zvWOVq1SZKkyqDhex9wOTAHHAfu6fcDImIhIpYiYml5eXnAMiSNA89nqT8DhW9mnsjMVzPzNHA//za0fAy4uGPVi6q2bp+xmJnzmTk/MzMzSBmSxoTns9SfgcI3IrZ0vLwBWL0T+hHg5oh4bURcCmwDvlyvREmS2mXD7/lGxG7gSuDCiDgKfBS4MiLmgASOALcBZObTEfE54OvAKeBDmfnqcEqXJGkybRi+mbmjS/MDZ1n/buDuOkWpnfr5fdxe1nVyDGl0VifQaGrdkhNxjAOnl5QkqTCnl1QxvfRURzG9pKT+9dJTHcX0kpPCnq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhTrKhseLkGlJ7OLnG+uz5SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJU2IbhGxE7I+JkRBzsaHsoIvZXjyMRsb9qvyQiftTx3ieGWbwkSZOol68aPQj8OfCp1YbM/I3V5Yi4B/h+x/qHM3OuqQIlSWqbDcM3Mx+PiEu6vRcrv3x+E/Cfmy1LkqT2qnvN9wrgRGZ+u6Pt0oj4akT8Q0RcUfPzJUlqnbozXO0Adne8Pg7MZuYLEfE24G8i4s2Z+dLaDSNiAVgAmJ2drVmGpFHyfJb6M3DPNyLOAX4deGi1LTNfycwXquV9wGHg9d22z8zFzJzPzPmZmZlBy5A0Bjyfpf7UGXb+FeCbmXl0tSEiZiJiU7V8GbANeLZeiZIktUsvXzXaDTwBvCEijkbErdVbN/PTQ84A7wQOVF89+ivg9sx8scmCJUmadL3c7bxjnfYPdGnbA+ypX5YkSe3lDFeSJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYZGZo66BiFgGfgB8d9S1NOBCPI5xMgnH8fOZ2Zrf4YuIl4FnRl1HAybh704vPI6yejqfxyJ8ASJiKTPnR11HXR7HeGnLcUyStvyZexzjpS3HscphZ0mSCjN8JUkqbJzCd3HUBTTE4xgvbTmOSdKWP3OPY7y05TiAMbrmK0nStBinnq8kSVPB8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSpsaOEbEVdHxDMRcSgi7hzWfiRJmjSRmc1/aMQm4FvAu4GjwFPAjsz8euM7kyRpwgyr57sdOJSZz2bmj4HPAtcPaV+SJE2UYYXvVuD5jtdHqzZJkqbeOaPacUQsAAsA55577tve+MY3jqoUaaT27dv33cycGXUddXg+Syt6PZ+HFb7HgIs7Xl9Utf1EZi4CiwDz8/O5tLQ0pFKk8RYRz426hro8n6UVvZ7Pwxp2fgrYFhGXRsRrgJuBR4a0L0mSJspQer6ZeSoi7gD+HtgE7MzMp4exL0mSJs3Qrvlm5l5g77A+X5KkSeUMV5IkFWb4SpJUmOErSVJh
hq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUNHL4RcXFEfDEivh4RT0fEf6naPxYRxyJif/W4trlyJUmafOfU2PYU8LuZ+ZWIOA/YFxGPVu/dm5kfr1+eJEntM3D4ZuZx4Hi1/HJEfAPY2lRhkiS1VSPXfCPiEuAtwD9XTXdExIGI2BkRFzSxD0mS2qJ2+EbEzwJ7gN/JzJeA+4DLgTlWesb3rLPdQkQsRcTS8vJy3TIkjZDns9SfWuEbET/DSvB+OjP/GiAzT2Tmq5l5Grgf2N5t28xczMz5zJyfmZmpU4akEfN8lvpT527nAB4AvpGZf9rRvqVjtRuAg4OXJ0lS+9S52/kdwPuAr0XE/qrtI8COiJgDEjgC3FarQkmSWqbO3c7/CESXt/YOXo4kSe3nDFeSJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjhq4kSEUR0+yVLSZNm15fezq4vvX3UZYyE4StJUmHn1P2AiDgCvAy8CpzKzPmIeB3wEHAJcAS4KTO/V3dfkiS1QVM933dl5lxmzlev7wQey8xtwGPVa0mSxPCGna8HdlXLu4D3Dmk/kiRNnCbCN4EvRMS+iFio2jZn5vFq+TvA5gb2I0lSK9S+5gv8cmYei4j/ADwaEd/sfDMzMyJy7UZVUC8AzM7ONlCG2mSjO5rXez/zjL9qKsDzWWez0R3N671/yxVPDKOcsVC755uZx6rnk8DDwHbgRERsAaieT3bZbjEz5zNzfmZmpm4ZkkbI81nqT62eb0ScC/y7zHy5Wv5V4A+BR4BbgD+qnj9ft1BNl/V6sKs9Xnu40uRYrwe72uNtcw93PXWHnTcDD1f/IJ4DfCYz/y4ingI+FxG3As8BN9XcjyRJrVErfDPzWeAXu7S/AFxV57MlSWorZ7iSJKkww1eSpMIMX0mSCmvie75SMd7lLLXHNN7lvMqeryRJhRm+kiQVZvhKklSY4StJUmGGryRJhRm+kiQVZvhKklSY4StJUmGGryRJhRm+kiQVZvhKklSY4StJUmED/7BCRLwBeKij6TLgD4Dzgd8Glqv2j2Tm3oErlCSpZQYO38x8BpgDiIhNwDHgYeCDwL2Z+fFGKpQkqWWaGna+Cjicmc819HmSJLVWU+F7M7C74/UdEXEgInZGxAUN7UOSpFaoHb4R8RrgPcBfVk33AZezMiR9HLhnne0WImIpIpaWl5e7rSJpQng+S/1poud7DfCVzDwBkJknMvPVzDwN3A9s77ZRZi5m5nxmzs/MzDRQhqRR8XyW+tNE+O6gY8g5IrZ0vHcDcLCBfUiS1BoD3+0MEBHnAu8Gbuto/uOImAMSOLLmPUmSpl6t8M3MHwA/t6btfbUqkiSp5ZzhSpKkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwmrNcKVyIoLMrP0safTefvndtT/jicN3NVCJRsWe74RYDc66z5Kk0TN8J0RENPIsSRo9w3dC2POVpPYwfCeEPV9Jag/Dd0LY85Wk9jB8J4Q9X0lqj57CNyJ2RsTJiDjY0fa6iHg0Ir5dPV9QtUdE/FlEHIqIAxHx1mEVP03s+UpSe/Ta830QuHpN253AY5m5DXiseg1wDbCteiwA99UvU/Z8Jak9egrfzHwceHFN8/XArmp5F/DejvZP5YongfMjYksTxU4ze76S1B51ZrjanJnHq+XvAJur5a3A8x3rHa3ajqOBtWmGq3564eNSs6TuvvyOd/W87vZ/+uIQK5ksjdxwlSv/Qvb1r2RELETEUkQsLS8vN1FGq9nz1TjzfJb6Uyd8T6wOJ1fPJ6v2Y8DFHetdVLX9lMxczMz5zJyfmZmpUcZ08Jqvxpnns9SfOuH7CHBLtXwL
8PmO9vdXdz3/EvD9juFpDcieryS1R0/XfCNiN3AlcGFEHAU+CvwR8LmIuBV4DripWn0vcC1wCPgh8MGGa55KbbrmK0nTrqfwzcwd67x1VZd1E/hQnaJ0Jnu+ktQeznA1IbzmK0ntUeerRirInq/UHk8cvmvUJWjE7PlKklSY4StJUmGGryRJhXnNV8V5/VlqD6eMHIw9X0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqbMPwjYidEXEyIg52tP1JRHwzIg5ExMMRcX7VfklE/Cgi9lePTwyzeEmSJlEvPd8HgavXtD0K/EJm/kfgW8CHO947nJlz1eP2ZsqUJKk9NgzfzHwceHFN2xcy81T18kngoiHUJklSKzVxzfe3gP+/4/WlEfHViPiHiLiigc+XJKlVav2kYETcBZwCPl01HQdmM/OFiHgb8DcR8ebMfKnLtgvAAsDs7GydMiSNmOez1J+Be74R8QHgOuA3s/qB1sx8JTNfqJb3AYeB13fbPjMXM3M+M+dnZmYGLUPSGPB8lvozUPhGxNXA7wHvycwfdrTPRMSmavkyYBvwbBOFSpLUFhsOO0fEbuBK4MKIOAp8lJW7m18LPBoRAE9Wdza/E/jDiPhX4DRwe2a+2PWDJUmaUhuGb2bu6NL8wDrr7gH21C1KkqQ2c4YrSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqbBav+c7aaofgeiq+lVESRPi5f3r9x3OmztdsBKpf1MRvmcL3bXrGMLSeDtb6K5dxxDWuGr9sHMvwVtnfUnl9BK8ddaXSvFvpiRJhbU6fAftxdr7lcbPoL1Ye78aRxv+rYyInRFxMiIOdrR9LCKORcT+6nFtx3sfjohDEfFMRPzasAqXJGlS9fK/hA8CV3dpvzcz56rHXoCIeBNwM/Dmapu/iIhNTRUrSVIbbBi+mfk48GKPn3c98NnMfCUz/wU4BGyvUZ8kSa1T52LIHRFxoBqWvqBq2wo837HO0apNkiRVBg3f+4DLgTngOHBPvx8QEQsRsRQRS8vLywOWIWkceD5L/RkofDPzRGa+mpmngfv5t6HlY8DFHateVLV1+4zFzJzPzPmZmZlBypA0Jjyfpf4MFL4RsaXj5Q3A6p3QjwA3R8RrI+JSYBvw5XolDm7Q2aqmeZYrv2alcTXobFXTPMtVfPK6UZegdWw4vWRE7AauBC6MiKPAR4ErI2IOSOAIcBtAZj4dEZ8Dvg6cAj6Uma8Op3QNS0RM9f+ASG0Sn7yOvO1vR12G1tgwfDNzR5fmB86y/t3A3XWKalJm9tWbM3RWGMAaR+fNne5r0oxp7vV2MoDHz1RM/ZKZGwZJL+tMG4egNY7Omzu9Yaj2ss60cQh6vEzFrxqtMlz7Zw9Y48pw7Z894PExFT1f1WMPWGoPe8DjwfBVTwxgqT0M4NEzfNUzA1hqDwN4tKbqmq/qG+U14F7C3+vTUu9GeQ14YU/X+Zd+yuKN7Z2d2PBV30oHcD897tV1DWGpN6UDuJfQXbtuG0PYYWcNpNQQ9KD7cYhc6l2pIeh+greJ7caZPV8NbNg94PUCtNs+u63r16Sk3g27B7xegHbr1XZbd2HPsVb1gO35qpZh9TC7fe7ZJkJZ7z17wFLvhtUD7hamizduXTdM13uvTT1gw1e1NR1w6wVvLwxgqZ6mA3i94O1FmwPY8FUjhhlw/Q4dO9Qs1TPMa8D9Dh23aai5k+GrxjQRwGs/o6mfhbT3K/WniQBe20sdNEjXbteG3q/hq0YZclJ7OBHH8Bi+alxTAVx3+NjhZ6m+pgK47vBx24afDV8NhT1gqT3sATdvw/CNiJ0RcTIiDna0PRQR+6vHkYjYX7VfEhE/6njvE8MsXuPNAJbawwBuVi893weBqzsbMvM3MnMuM+eAPcBfd7x9ePW9zLy9uVI1iQxgqT0M
4OZsGL6Z+TjwYrf3YuVf1puA3Q3XpRYxgKX2MICbUXd6ySuAE5n57Y62SyPiq8BLwO9n5pdq7kMFjdNNSnWnhzT0Ne1G9YtF3dSdHrINXy/qVDd8d/DTvd7jwGxmvhARbwP+JiLenJkvrd0wIhaABYDZ2dmaZUgaJc9nqT8D3+0cEecAvw48tNqWma9k5gvV8j7gMPD6bttn5mJmzmfm/MzMzKBlqGWamhyjqck61BvPZ3XT1OQYTU3WMU7qfNXoV4BvZubR1YaImImITdXyZcA24Nl6JWra9RvADjdL46vfAG7bcPOqXr5qtBt4AnhDRByNiFurt27mzBut3gkcqL569FfA7ZnZ9WYtaT11fhyhzo8ySGpenR9HqPOjDONuw2u+mbljnfYPdGnbw8pXj6RaMvOMIF193U84G7zS6C3euPWMIF193U84tyV4of4NV9LQdAtg6L0XbPBK46NbAEPvveA2BS84vaTGXFO/aiRp9Jr6VaM2sOersbcapL30eA1dabytBmkvPd42hu4qw1cTw2CV2qPNwdoLh50lSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKi3GYuCAiloEfAN8ddS0NuBCPY5xMwnH8fGa25kdwI+Jl4JlR19GASfi70wuPo6yezuexCF+AiFjKzPlR11GXxzFe2nIck6Qtf+Yex3hpy3GscthZkqTCDF9Jkgobp/BdHHUBDfE4xktbjmOStOXP3OMYL205DmCMrvlKkjQtxqnnK0nSVDB8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSCjN8JUkqzPCVJKkww1eSpMIMX0mSChta+EbE1RHxTEQciog7h7UfSZImTWRm8x8asQn4FvBu4CjwFLAjM7/e+M4kSZoww+r5bgcOZeazmflj4LPA9UPalyRJE+WcIX3uVuD5jtdHgf/UuUJELAALAOeee+7b3vjGNw6pFGm87du377uZOTPqOurwfJZW9HoA/shgAAAQwUlEQVQ+Dyt8N5SZi8AiwPz8fC4tLY2qFGmkIuK5UddQl+eztKLX83lYw87HgIs7Xl9UtUmSNPWGFb5PAdsi4tKIeA1wM/DIkPYlSdJEGcqwc2aeiog7gL8HNgE7M/PpYexLkqRJM7Rrvpm5F9g7rM+XJGlSOcOVJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFGb6SJBVm+EqSVJjh25CIICJGXYakBuz60tvZ9aW3j7oMtZjhK0lSYYavJEmFGb6SJBVm+EqSVJjhK0lSYYavJEmFDe0nBduol68SnW2dzGyyHEk19PJVorOtc8sVTzRZjqaMPV9JkgobuOcbERcDnwI2AwksZuZ/j4iPAb8NLFerfiQz99YtdBycree62uO1dytNhrP1XFd7vPZuNSx1hp1PAb+bmV+JiPOAfRHxaPXevZn58frlSZLUPgOHb2YeB45Xyy9HxDeArU0VJklSWzVyw1VEXAK8Bfhn4B3AHRHxfmCJld7x97psswAsAMzOzjZRhgYQEWRm7WdNN8/n8fD2y++u/RlPHL6rgUq0kdo3XEXEzwJ7gN/JzJeA+4DLgTlWesb3dNsuMxczcz4z52dmZuqWoQGtBmfdZ003z2epP7XCNyJ+hpXg/XRm/jVAZp7IzFcz8zRwP7C9fpkaltUbxeo+S5J6N3D4xsq/ug8A38jMP+1o39Kx2g3AwcHL07DZ85Wk8upc830H8D7gaxGxv2r7CLAjIuZY+frREeC2WhVOiEkNIa/5SmfyK0Yatjp3O/8j0G3MsRXf6Z0W9nwlqTxnuJpyXvOVpPIM3ylnz1eSyjN8p5w9X0kqz/CdcvZ8Jak8w3fK2fOVpPIM
3ylnz1eSyjN8p5w9X0kqz/CdcvZ8Jak8w3fK2fOVpPIM3ylnz1eSyjN8p5w9X0kqr84PK6gF7PlK7fHE4btGXYJ6ZM9XkqTCDF9JkgozfCVJKszwlSSpsNo3XEXEEeBl4FXgVGbOR8TrgIeAS4AjwE2Z+b26+5IkqQ2a6vm+KzPnMnO+en0n8FhmbgMeq15LkiSGN+x8PbCrWt4FvHdI+5EkaeI0Eb4JfCEi9kXEQtW2OTOPV8vfATav3SgiFiJiKSKWlpeXGyhD0qh4Pkv9aSJ8fzkz3wpcA3woIt7Z+WauzMJwxkwMmbmYmfOZOT8zM9NAGZJGxfNZ6k/t8M3MY9XzSeBhYDtwIiK2AFTPJ+vuR5KktqgVvhFxbkSct7oM/CpwEHgEuKVa7Rbg83X2I0lSm9T9qtFm4OFqcv1zgM9k5t9FxFPA5yLiVuA54Kaa+5EkqTVqhW9mPgv8Ypf2F4Cr6ny2JElt5QxXkiQVZvhKklSY4StJUmG153bWZKtuljurla9qSxp3X37HuzZcZ/s/fbFAJdqIPV9JkgozfCVJKszwlSSpMMNXkqTCDF9JkgozfCVJKszwlSSpMMNXkqTCnGRjyjmBhtQeTqAxOez5SpJUmOErSVJhhq8kSYUNfM03It4APNTRdBnwB8D5wG8Dy1X7RzJz78AVSpLUMgOHb2Y+A8wBRMQm4BjwMPBB4N7M/HgjFUqS1DJNDTtfBRzOzOca+jxJklqrqfC9Gdjd8fqOiDgQETsj4oJuG0TEQkQsRcTS8vJyt1UkTQjPZ6k/tcM3Il4DvAf4y6rpPuByVoakjwP3dNsuMxczcz4z52dmZuqWIWmEPJ+l/jTR870G+EpmngDIzBOZ+WpmngbuB7Y3sA9JklqjifDdQceQc0Rs6XjvBuBgA/uQJKk1ak0vGRHnAu8Gbuto/uOImAMSOLLmPUmSpl6t8M3MHwA/t6btfbUqkiSp5ZzhSpKkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqbCewjcidkbEyYg42NH2uoh4NCK+XT1fULVHRPxZRByKiAMR8dZhFS9J0iTqtef7IHD1mrY7gccycxvwWPUa4BpgW/VYAO6rX6YkSe3RU/hm5uPAi2uarwd2Vcu7gPd2tH8qVzwJnB8RW5ooVpKkNqhzzXdzZh6vlr8DbK6WtwLPd6x3tGr7KRGxEBFLEbG0vLxcowxJo+b5LPWnkRuuMjOB7HObxcycz8z5mZmZJsqQNCKez1J/6oTvidXh5Or5ZNV+DLi4Y72LqjZJkkS98H0EuKVavgX4fEf7+6u7nn8J+H7H8LQkSVPvnF5WiojdwJXAhRFxFPgo8EfA5yLiVuA54KZq9b3AtcAh4IfABxuuWZKkidZT+GbmjnXeuqrLugl8qE5RkiS1mTNcSZJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFbZh+EbEzog4GREHO9r+JCK+GREHIuLhiDi/ar8kIn4UEfurxyeGWbwkSZOol57vg8DVa9oeBX4hM/8j8C3gwx3vHc7MuepxezNlKiJGXYKkhsQnrxt1CRqxDcM3Mx8HXlzT9oXMPFW9fBK4aAi1aQ0DWGoPA3i6ndPAZ/wW8FDH60sj4qvAS8DvZ+aXum0UEQvAAsDs7GwDZYyXXoIyMwf63EG2k4ap7efzwp5jG66zeOPWvj83PnkdedvfDlKSJlytG64i4i7gFPDpquk4MJuZbwH+K/CZiPj33bbNzMXMnM/M+ZmZmTpljJ1ee6iD9mTtAWvctPl87iV4+1lvLXvA02ng8I2IDwDXAb+ZVVcsM1/JzBeq5X3AYeD1DdQ5MboFY2b+5NHL+oPuR1KzugXq
4o1bf/LoZf1eGMDTZ6DwjYirgd8D3pOZP+xon4mITdXyZcA24NkmCp0EawOxW+B2azOApfGzNki7BW63NgNYvejlq0a7gSeAN0TE0Yi4Ffhz4Dzg0TVfKXoncCAi9gN/BdyemS92/eCW6Ra8Z2MAS+OrW/CejQGsfm14w1Vm7ujS/MA66+4B9tQtatL1ekNUZjYSnt6EJQ1PrzdSLd64deDQ7eRNWNPBGa5awh6w1B72gNvP8G1Yvz3QJnusBrDUrH6/PjTI143WYwC3m+HbMgaw1B4GcHsZvi1kAEvtYQC3UxMzXKlDvzc/9RqU3lAllbew51hfQ8m93nDlDVWy5ytJUmGG7xAMe3pJSeUMe3pJTSfDtyH9TprR76Qcksrpd9KMfiflkAzfBnUL4LUh263N4JXGT7cAXhuy3doMXvXCG64a1m3WqrP1gg1eaXx1m7XqbL1gg1e9suc7BP1MLylpvPUzvaTUK3u+Q2KwSu1hsKpp9nwlSSrM8JUkqTDDV5KkwgxfSZIK2zB8I2JnRJyMiIMdbR+LiGMRsb96XNvx3ocj4lBEPBMRvzaswiVJmlS99HwfBK7u0n5vZs5Vj70AEfEm4GbgzdU2fxERm5oqVpKkNtgwfDPzceDFHj/veuCzmflKZv4LcAjYXqM+SZJap8413zsi4kA1LH1B1bYVeL5jnaNV2xkiYiEiliJiaXl5uUYZkkbN81nqz6Dhex9wOTAHHAfu6fcDMnMxM+czc35mZmbAMiSNA89nqT8DhW9mnsjMVzPzNHA//za0fAy4uGPVi6o2SZJUGSh8I2JLx8sbgNU7oR8Bbo6I10bEpcA24Mv1SpQkqV02nNs5InYDVwIXRsRR4KPAlRExByRwBLgNIDOfjojPAV8HTgEfysxXh1O6JEmTacPwzcwdXZofOMv6dwN31ylKkqQ2c4YrSZIKM3wlSSrM3/OVgIjYcB1/o1maDC/v37hfed7c6QKVrM+er6ZeL8Hbz3qSRqeX4O1nvWGx56upNUiYrm5jL1gaL4OE6eo2o+gF2/OVJKkww1dTqe4QskPQ0vioO4Q8iiFow1dTp6ngNICl0WsqOEsHsOErSVJhhq8kSYUZvpIkFWb4SpJUmOErSVJhhq8kSYUZvpIkFWb4SpJU2IbhGxE7I+JkRBzsaHsoIvZXjyMRsb9qvyQiftTx3ieGWbw0iKbmZXZ+Z2n0mpqXufT8zr38sMKDwJ8Dn1ptyMzfWF2OiHuA73esfzgz55oqUBqGzKw1Q5XBK42P8+ZO15qhahQ/rLBh+Gbm4xFxSbf3YuVfr5uA/9xsWZIktVfda75XACcy89sdbZdGxFcj4h8i4or1NoyIhYhYioil5eXlmmVI/cvMvnuwg2wzDTyfNWrnzZ3uuwc7yDZNqRu+O4DdHa+PA7OZ+RbgvwKfiYh/323DzFzMzPnMnJ+ZmalZhjS4XsPU0F2f57PGRa9hOqrQXdXLNd+uIuIc4NeBt622ZeYrwCvV8r6IOAy8HliqWac0VAar1B6jDtZe1On5/grwzcw8utoQETMRsalavgzYBjxbr0RJktqll68a7QaeAN4QEUcj4tbqrZv56SFngHcCB6qvHv0VcHtmvthkwZIkTbpe7nbesU77B7q07QH21C9LkqT2coYrSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqTDDV5KkwgxfSZIKM3wlSSrM8JUkqbAYh59Si4hl4AfAd0ddSwMuxOMYJ5NwHD+fma35EdyIeBl4ZtR1NGAS/u70wuMoq6fzeSzCFyAiljJzftR11OVxjJe2HMckacufuccxXtpyHKscdpYkqTDDV5KkwsYpfBdHXUBDPI7x0pbjmCRt+TP3OMZLW44DGKNrvpIkTYtx6vlKkjQVRh6+EXF1RDwTEYci4s5R19OPiDgSEV+LiP0RsVS1vS4iHo2Ib1fPF4y6zrUiYmdEnIyIgx1tXeuOFX9W/fc5EBFvHV3lP22d
4/hYRByr/pvsj4hrO977cHUcz0TEr42m6nbzfC7P83kyz+eRhm9EbAL+B3AN8CZgR0S8aZQ1DeBdmTnXcQv8ncBjmbkNeKx6PW4eBK5e07Ze3dcA26rHAnBfoRp78SBnHgfAvdV/k7nM3AtQ/b26GXhztc1fVH//1BDP55F5EM/niTufR93z3Q4cysxnM/PHwGeB60dcU13XA7uq5V3Ae0dYS1eZ+Tjw4prm9eq+HvhUrngSOD8itpSp9OzWOY71XA98NjNfycx/AQ6x8vdPzfF8HgHP58k8n0cdvluB5zteH63aJkUCX4iIfRGxULVtzszj1fJ3gM2jKa1v69U9if+N7qiG1HZ2DBNO4nFMmkn/M/Z8Hk+tPJ9HHb6T7pcz862sDOV8KCLe2flmrtxKPnG3k09q3ZX7gMuBOeA4cM9oy9EE8XweP609n0cdvseAizteX1S1TYTMPFY9nwQeZmXY48TqME71fHJ0FfZlvbon6r9RZp7IzFcz8zRwP/82FDVRxzGhJvrP2PN5/LT5fB51+D4FbIuISyPiNaxcQH9kxDX1JCLOjYjzVpeBXwUOslL/LdVqtwCfH02FfVuv7keA91d3Sf4S8P2O4ayxs+b61Q2s/DeBleO4OSJeGxGXsnLDyZdL19dyns/jw/N53GXmSB/AtcC3gMPAXaOup4+6LwP+d/V4erV24OdYubvw28D/Al436lq71L6blSGcf2XlWsmt69UNBCt3sB4GvgbMj7r+DY7jf1Z1HmDlBN3Ssf5d1XE8A1wz6vrb+PB8Hkntns8TeD47w5UkSYWNethZkqSpY/hKklSY4StJUmGGryRJhRm+kiQVZvhKklSY4StJUmGGryRJhf1faexlQszjvd8AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 576x864 with 6 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os,sys\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import helper\n",
    "import simulation\n",
    "\n",
    "# Draw three synthetic samples: input images plus their per-class target masks\n",
    "input_images, target_masks = simulation.generate_random_data(192, 192, count=3)\n",
    "\n",
    "print(input_images.shape, target_masks.shape)\n",
    "\n",
    "# Move channels last, (C, H, W) -> (H, W, C), then invert and rescale to uint8 for matplotlib\n",
    "# (presumably the generated pixel values lie in [0, 1] -- TODO confirm in simulation)\n",
    "input_images_rgb = [\n",
    "    (np.transpose(image, (1, 2, 0)) * -255 + 255).astype(np.uint8)\n",
    "    for image in input_images\n",
    "]\n",
    "\n",
    "# Collapse each stack of class masks into one colour image (one colour per class channel)\n",
    "target_masks_rgb = list(map(helper.masks_to_colorimg, target_masks))\n",
    "\n",
    "# Left column: input images, right column: target masks\n",
    "helper.plot_side_by_side([input_images_rgb, target_masks_rgb])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),\n",
       " BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),\n",
       " ReLU(inplace),\n",
       " MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),\n",
       " Sequential(\n",
       "   (0): BasicBlock(\n",
       "     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace)\n",
       "     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       "   (1): BasicBlock(\n",
       "     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace)\n",
       "     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       " ),\n",
       " Sequential(\n",
       "   (0): BasicBlock(\n",
       "     (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace)\n",
       "     (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (downsample): Sequential(\n",
       "       (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "       (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     )\n",
       "   )\n",
       "   (1): BasicBlock(\n",
       "     (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace)\n",
       "     (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       " ),\n",
       " Sequential(\n",
       "   (0): BasicBlock(\n",
       "     (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace)\n",
       "     (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (downsample): Sequential(\n",
       "       (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "       (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     )\n",
       "   )\n",
       "   (1): BasicBlock(\n",
       "     (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace)\n",
       "     (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       " ),\n",
       " Sequential(\n",
       "   (0): BasicBlock(\n",
       "     (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace)\n",
       "     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (downsample): Sequential(\n",
       "       (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "       (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     )\n",
       "   )\n",
       "   (1): BasicBlock(\n",
       "     (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace)\n",
       "     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       " ),\n",
       " AvgPool2d(kernel_size=7, stride=1, padding=0),\n",
       " Linear(in_features=512, out_features=1000, bias=True)]"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torchvision import models\n",
    "\n",
    "base_model = models.resnet18(pretrained=True)\n",
    "\n",
    "def find_last_layer(layer):\n",
    "    children = list(layer.children())\n",
    "    if len(children) == 0:\n",
    "        return layer\n",
    "    else:\n",
    "        return find_last_layer(children[-1])\n",
    "    \n",
    "list(base_model.children())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'OrderedDict' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-86-e31efe363266>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mmodel_wo_avgpool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSequential\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbase_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mOrderedDict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_wo_avgpool\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnamed_children\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'OrderedDict' is not defined"
     ]
    }
   ],
   "source": [
    "from torch import nn\n",
    "\n",
    "model_wo_avgpool = nn.Sequential(*list(base_model.children())[:-2])\n",
    "\n",
    "#OrderedDict(model_wo_avgpool.named_children())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "class FCN(nn.Module):\n",
    "    \"\"\"Fully convolutional segmentation network built on a pretrained ResNet-18.\n",
    "\n",
    "    The outputs of the four residual stages are bilinearly upsampled back to the\n",
    "    input resolution, concatenated along the channel axis and reduced to `n_class`\n",
    "    per-pixel sigmoid scores by a single 1x1 convolution.\n",
    "\n",
    "    The fixed scale factors (4, 8, 16, 32) only produce maps of equal size when the\n",
    "    input height and width are multiples of 32 (e.g. 192 or 224).\n",
    "\n",
    "    Args:\n",
    "        n_class: number of output channels, one score map per class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_class):\n",
    "        super().__init__()\n",
    "\n",
    "        self.base_model = models.resnet18(pretrained=True)\n",
    "\n",
    "        # Slice the stages out of this instance's own backbone rather than the\n",
    "        # notebook-level `base_model` global: otherwise every FCN would share the\n",
    "        # global's layers and the class would break on a fresh kernel.\n",
    "        layers = list(self.base_model.children())\n",
    "\n",
    "        # Stem (stride-2 conv + stride-2 max-pool) followed by the first residual stage\n",
    "        self.layer1 = nn.Sequential(*layers[:5])  # size=(N, 64, x.H/4, x.W/4)\n",
    "        self.layer2 = layers[5]  # size=(N, 128, x.H/8, x.W/8)\n",
    "        self.layer3 = layers[6]  # size=(N, 256, x.H/16, x.W/16)\n",
    "        self.layer4 = layers[7]  # size=(N, 512, x.H/32, x.W/32)\n",
    "\n",
    "        # Each upsampler undoes the total stride of its stage. align_corners=False is\n",
    "        # the library default since torch 0.4.0; passing it explicitly keeps the\n",
    "        # behaviour and silences the UserWarning.\n",
    "        self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)\n",
    "        self.upsample2 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)\n",
    "        self.upsample3 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)\n",
    "        self.upsample4 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=False)\n",
    "\n",
    "        # 1x1 convolution over the concatenated multi-scale features\n",
    "        self.conv1k = nn.Conv2d(64 + 128 + 256 + 512, n_class, 1)\n",
    "        self.sigmoid = nn.Sigmoid()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-pixel class scores in (0, 1) for an image batch `x`.\"\"\"\n",
    "        x = self.layer1(x)\n",
    "        up1 = self.upsample1(x)\n",
    "        x = self.layer2(x)\n",
    "        up2 = self.upsample2(x)\n",
    "        x = self.layer3(x)\n",
    "        up3 = self.upsample3(x)\n",
    "        x = self.layer4(x)\n",
    "        up4 = self.upsample4(x)\n",
    "\n",
    "        # Fuse all scales along the channel axis, then project to n_class maps\n",
    "        merge = torch.cat([up1, up2, up3, up4], dim=1)\n",
    "        merge = self.conv1k(merge)\n",
    "        out = self.sigmoid(merge)\n",
    "\n",
    "        return out\n",
    "\n",
    "fcn_model = FCN(6)\n",
    "\n",
    "import torchsummary\n",
    "\n",
    "#torchsummary.summary(fcn_model, input_size=(3, 224, 224), device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, criterion, optimizer, scheduler, num_epochs=25,\n",
    "                batch_size=10, epoch_steps=10):\n",
    "    \"\"\"Train `model` on freshly simulated batches and keep the best validation weights.\n",
    "\n",
    "    Every epoch runs a 'train' phase followed by a 'val' phase. Each phase draws\n",
    "    `epoch_steps` new batches of `batch_size` 192x192 samples from\n",
    "    `simulation.generate_random_data`. The weights with the lowest validation loss\n",
    "    are snapshotted and loaded back into the model before it is returned.\n",
    "\n",
    "    Args:\n",
    "        model: network to optimise; it is modified in place.\n",
    "        criterion: callable `criterion(outputs, labels)` returning a scalar loss tensor.\n",
    "        optimizer: optimizer stepped once per training batch.\n",
    "        scheduler: learning-rate scheduler stepped once per epoch.\n",
    "        num_epochs: number of train/val epochs to run.\n",
    "        batch_size: number of samples generated per step.\n",
    "        epoch_steps: number of steps per phase.\n",
    "\n",
    "    Returns:\n",
    "        The same `model` object with the best validation weights loaded.\n",
    "    \"\"\"\n",
    "    # Imported locally so the function does not depend on another cell having run\n",
    "    import copy\n",
    "    import time\n",
    "\n",
    "    # Put batches on the device that holds the model's parameters instead of\n",
    "    # relying on a notebook-level `device` global.\n",
    "    first_param = next(model.parameters(), None)\n",
    "    device = first_param.device if first_param is not None else torch.device('cpu')\n",
    "\n",
    "    since = time.time()\n",
    "\n",
    "    best_model_wts = copy.deepcopy(model.state_dict())\n",
    "    best_loss = 1e10\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        print('Epoch {}/{}'.format(epoch, num_epochs - 1))\n",
    "        print('-' * 10)\n",
    "\n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in ['train', 'val']:\n",
    "            if phase == 'train':\n",
    "                # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after the\n",
    "                # epoch's optimizer.step() calls; left at the start of the epoch to\n",
    "                # keep the schedule unchanged -- confirm against the installed version.\n",
    "                scheduler.step()\n",
    "                model.train()  # Set model to training mode\n",
    "            else:\n",
    "                model.eval()   # Set model to evaluate mode\n",
    "\n",
    "            running_loss = 0.0\n",
    "\n",
    "            # Iterate over freshly generated data.\n",
    "            for _ in range(epoch_steps):\n",
    "                input_images, target_masks = simulation.generate_random_data(192, 192, count=batch_size)\n",
    "\n",
    "                inputs = torch.from_numpy(input_images).to(device)\n",
    "                labels = torch.from_numpy(target_masks).to(device)\n",
    "\n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                # track history if only in train\n",
    "                with torch.set_grad_enabled(phase == 'train'):\n",
    "                    outputs = model(inputs)\n",
    "                    loss = criterion(outputs, labels)\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == 'train':\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # statistics: weight the batch loss by the number of samples in it\n",
    "                running_loss += loss.item() * inputs.size(0)\n",
    "\n",
    "            epoch_loss = running_loss / (batch_size * epoch_steps)\n",
    "\n",
    "            print('{} Loss: {:.4f}'.format(\n",
    "                phase, epoch_loss))\n",
    "\n",
    "            # deep copy the model whenever validation improves\n",
    "            if phase == 'val' and epoch_loss < best_loss:\n",
    "                best_loss = epoch_loss\n",
    "                best_model_wts = copy.deepcopy(model.state_dict())\n",
    "\n",
    "        print()\n",
    "\n",
    "    time_elapsed = time.time() - since\n",
    "    print('Training complete in {:.0f}m {:.0f}s'.format(\n",
    "        time_elapsed // 60, time_elapsed % 60))\n",
    "    # '{:.4f}' (four decimals) rather than '{:4f}' (minimum width 4), matching the\n",
    "    # per-phase loss output above\n",
    "    print('Best val loss: {:.4f}'.format(best_loss))\n",
    "\n",
    "    # load best model weights\n",
    "    model.load_state_dict(best_model_wts)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0/24\n",
      "----------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/user/miniconda/envs/py36/lib/python3.6/site-packages/torch/nn/functional.py:1749: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n",
      "  \"See the documentation of nn.Upsample for details.\".format(mode))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train Loss: 0.4876\n",
      "val Loss: 0.2639\n",
      "\n",
      "Epoch 1/24\n",
      "----------\n",
      "train Loss: 0.1647\n",
      "val Loss: 0.0984\n",
      "\n",
      "Epoch 2/24\n",
      "----------\n",
      "train Loss: 0.0781\n",
      "val Loss: 0.0628\n",
      "\n",
      "Epoch 3/24\n",
      "----------\n",
      "train Loss: 0.0570\n",
      "val Loss: 0.0539\n",
      "\n",
      "Epoch 4/24\n",
      "----------\n",
      "train Loss: 0.0469\n",
      "val Loss: 0.0464\n",
      "\n",
      "Epoch 5/24\n",
      "----------\n",
      "train Loss: 0.0443\n",
      "val Loss: 0.0441\n",
      "\n",
      "Epoch 6/24\n",
      "----------\n",
      "train Loss: 0.0420\n",
      "val Loss: 0.0425\n",
      "\n",
      "Epoch 7/24\n",
      "----------\n",
      "train Loss: 0.0416\n",
      "val Loss: 0.0417\n",
      "\n",
      "Epoch 8/24\n",
      "----------\n",
      "train Loss: 0.0407\n",
      "val Loss: 0.0419\n",
      "\n",
      "Epoch 9/24\n",
      "----------\n",
      "train Loss: 0.0411\n",
      "val Loss: 0.0413\n",
      "\n",
      "Epoch 10/24\n",
      "----------\n",
      "train Loss: 0.0405\n",
      "val Loss: 0.0397\n",
      "\n",
      "Epoch 11/24\n",
      "----------\n",
      "train Loss: 0.0402\n",
      "val Loss: 0.0392\n",
      "\n",
      "Epoch 12/24\n",
      "----------\n",
      "train Loss: 0.0406\n",
      "val Loss: 0.0395\n",
      "\n",
      "Epoch 13/24\n",
      "----------\n",
      "train Loss: 0.0402\n",
      "val Loss: 0.0396\n",
      "\n",
      "Epoch 14/24\n",
      "----------\n",
      "train Loss: 0.0405\n",
      "val Loss: 0.0406\n",
      "\n",
      "Epoch 15/24\n",
      "----------\n",
      "train Loss: 0.0411\n",
      "val Loss: 0.0407\n",
      "\n",
      "Epoch 16/24\n",
      "----------\n",
      "train Loss: 0.0406\n",
      "val Loss: 0.0411\n",
      "\n",
      "Epoch 17/24\n",
      "----------\n",
      "train Loss: 0.0410\n",
      "val Loss: 0.0401\n",
      "\n",
      "Epoch 18/24\n",
      "----------\n",
      "train Loss: 0.0396\n",
      "val Loss: 0.0398\n",
      "\n",
      "Epoch 19/24\n",
      "----------\n",
      "train Loss: 0.0403\n",
      "val Loss: 0.0403\n",
      "\n",
      "Epoch 20/24\n",
      "----------\n",
      "train Loss: 0.0410\n",
      "val Loss: 0.0415\n",
      "\n",
      "Epoch 21/24\n",
      "----------\n",
      "train Loss: 0.0402\n",
      "val Loss: 0.0402\n",
      "\n",
      "Epoch 22/24\n",
      "----------\n",
      "train Loss: 0.0393\n",
      "val Loss: 0.0398\n",
      "\n",
      "Epoch 23/24\n",
      "----------\n",
      "train Loss: 0.0405\n",
      "val Loss: 0.0398\n",
      "\n",
      "Epoch 24/24\n",
      "----------\n",
      "train Loss: 0.0409\n",
      "val Loss: 0.0410\n",
      "\n",
      "Training complete in 1m 11s\n",
      "Best val loss: 0.039217\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim import lr_scheduler\n",
    "import time\n",
    "import copy\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model_ft = FCN(6).to(device)\n",
    "\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "# Observe that all parameters are being optimized\n",
    "optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)\n",
    "\n",
    "# Decay LR by a factor of 0.1 every 7 epochs\n",
    "exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)\n",
    "\n",
    "model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:py36]",
   "language": "python",
   "name": "conda-env-py36-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
